{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "dataset_filename = \"affinity_dataset.txt\"\n",
    "X = np.loadtxt(dataset_filename)\n",
    "n_samples, n_features= X.shape\n",
    "features = [\"bread\", \"milk\", \"cheese\", \"apples\", \"bananas\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "valid_rules = defaultdict(int)\n",
    "invalid_rules = defaultdict(int)\n",
    "num_occurences = defaultdict(int)\n",
    "for sample in X:\n",
    "    for premise in range(n_features):\n",
    "        if sample[premise] == 0: continue # come back to the loop head\n",
    "        num_occurences[premise] += 1\n",
    "        for conclusion in range(n_features):\n",
    "            if premise == conclusion:\n",
    "                continue\n",
    "            if sample[conclusion] == 1:\n",
    "                valid_rules[(premise, conclusion)] += 1\n",
    "            else:\n",
    "                invalid_rules[(premise, conclusion)] += 1\n",
    "support = valid_rules\n",
    "confidence = defaultdict(float)\n",
    "for premise, conclusion in valid_rules.keys():\n",
    "    rule = (premise, conclusion)\n",
    "    confidence[rule] = valid_rules[rule] / num_occurences[premise] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_rule(premise, conclusion, support, confidence, features):\n",
    "    premise_name = features[premise]\n",
    "    conclusion_name = features[conclusion]\n",
    "    print(\"Rule: If a person buys {0} they will also buy {1}\".format(premise_name, conclusion_name))\n",
    "    print(\" - Confidence: {0:.3f}\".format(confidence[(premise, conclusion)]))\n",
    "    print(\" - Support: {0}\".format(support[(premise, conclusion)]))\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "bread\n",
      "Rule: If a person buys milk they will also buy bananas\n",
      " - Confidence: 0.519\n",
      " - Support: 27\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(conclusion)\n",
    "print(features[conclusion])\n",
    "premise = 1\n",
    "conclusion = 4\n",
    "print_rule(premise, conclusion, support, confidence, features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rule #1\n",
      "Rule: If a person buys apples they will also buy bananas\n",
      " - Confidence: 0.628\n",
      " - Support: 27\n",
      "\n",
      "Rule #2\n",
      "Rule: If a person buys bananas they will also buy apples\n",
      " - Confidence: 0.474\n",
      " - Support: 27\n",
      "\n",
      "Rule #3\n",
      "Rule: If a person buys milk they will also buy bananas\n",
      " - Confidence: 0.519\n",
      " - Support: 27\n",
      "\n",
      "Rule #4\n",
      "Rule: If a person buys bananas they will also buy milk\n",
      " - Confidence: 0.474\n",
      " - Support: 27\n",
      "\n",
      "Rule #5\n",
      "Rule: If a person buys cheese they will also buy apples\n",
      " - Confidence: 0.564\n",
      " - Support: 22\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from operator import itemgetter\n",
    "sorted_support = sorted(support.items(), key = itemgetter(1), reverse = True)\n",
    "for index in range(5):\n",
    "    print(\"Rule #{0}\".format(index + 1))\n",
    "    (premise, conclusion) = sorted_support[index][0]\n",
    "    print_rule(premise, conclusion, support, confidence, features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rule #1\n",
      "Rule: If a person buys apples they will also buy bananas\n",
      " - Confidence: 0.628\n",
      " - Support: 27\n",
      "\n",
      "Rule #2\n",
      "Rule: If a person buys bread they will also buy bananas\n",
      " - Confidence: 0.571\n",
      " - Support: 16\n",
      "\n",
      "Rule #3\n",
      "Rule: If a person buys cheese they will also buy apples\n",
      " - Confidence: 0.564\n",
      " - Support: 22\n",
      "\n",
      "Rule #4\n",
      "Rule: If a person buys milk they will also buy bananas\n",
      " - Confidence: 0.519\n",
      " - Support: 27\n",
      "\n",
      "Rule #5\n",
      "Rule: If a person buys cheese they will also buy bananas\n",
      " - Confidence: 0.513\n",
      " - Support: 20\n",
      "\n"
     ]
    }
   ],
   "source": [
    "sorted_confidence = sorted(confidence.items(), key = itemgetter(1), reverse = True)\n",
    "for index in range(5):\n",
    "    print(\"Rule #{0}\".format(index + 1))\n",
    "    (premise, conclusion) = sorted_confidence[index][0]\n",
    "    print_rule(premise, conclusion, support, confidence, features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1f640af4a90>]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXhU9b3H8fd3MlkIWVgS9i2yCEH2EBBwaatWpIrWBbAqKopYl/Zqe7WbVdvHXtt6rbviikvFfa1La28VFxQSBCQssktkC3sCZP/dPzJqjBMSYGbOzOTzep48s5zfzPnkOH44OXMWc84hIiKxz+d1ABERCQ0VuohInFChi4jECRW6iEicUKGLiMQJv1czzsrKcr169fJq9iIiMamwsHCbcy472DTPCr1Xr14UFBR4NXsRkZhkZusbm6ZNLiIicUKFLiISJ1ToIiJxQoUuIhInVOgiInFChS4iEidU6CIicSLmCn3z7nJufLWIqppar6OIiESVmCv0hRt28thH67j9X597HUVEJKrEXKGffFRnJo/szn3vreaj1du8jiMiEjVirtABbjg1l5z2rbnmmUXs2lfpdRwRkagQk4WemuTnjsnD2L63gutf+AxdRk9EJEYLHWBQt0x+cdKRvFW0mWfmb/A6joiI52K20AEuPeYIxvZpz02vLWXV1jKv44iIeCqmC93nM/73nKGkJPr42exPqaiu8TqSiIhnYrrQATpmpHDrmYMp2riH2/6pXRlFpOVqVqGb2clmtsLMVpnZ9Y2MOd7MFppZkZm9F9qYB3bSwE6cN7oHM+es4f2VJZGctYhI1Giy0M0sAbgHGA/kAlPMLLfBmDbAvcBpzrmBwNlhyHpAvzkllz4d0rj22UVsL6uI9OxFRDzXnDX0fGCVc26Nc64SmA1MbDDmXOBF59wXAM65raGN2bRWSQncOXkYu/ZVcd0Li7Uro4i0OM0p9K5A/f0CiwPP1dcPaGtm75pZoZldEOyNzGy6mRWYWUFJSeg3jeR2yeC68f15Z9lWnvzki5C/v4hINGtOoVuQ5xqu/vqBEcAE4IfA78ys33de5NxM51yecy4vOzvoRasP20VjenFcv2z++PpSPt9SGpZ5iIhEo+YUejHQvd7jbsDGIGPecs7tdc5tA+YAQ0IT8eD4fMZfzx5CWrKfq5/+lPIq7cooIi1Dcwp9PtDXzHLMLAmYDLzaYMwrwDFm5jezVGAUsCy0UZsvOz2Zv549hOWbS7n1reVexRARiagmC905Vw1cCbxNXUk/65wrMrMZZjYjMGYZ8BawGJgHPOScWxK+2E37Xv8OXDimF49+uI7/rIj4d7QiIhFnXu0NkpeX5woKCsI6j/KqGibe/SHb91bw5s+OJTs9OazzExEJNzMrdM7lBZsW80eKHkhKYgJ3ThnGnvJqfvn8Iu3KKCJxLa4LHeDITun8dsIA3l1RwmMfrfM6johI2MR9oQOcP7onP+jfgT+9sZxlm/Z4HUdEJCxaRKGbGX8+azCZqYnalVFE4laLKHSA9mnJ3Hb2EFZuLePm15dSVVPrdSQRkZDyex0gko7tl80l43J46IO1PF9QTN+OafTvlMGAzukM6JxB/07ptE/TnjAiEptaVKEDXD++P0O6t2HJl7tZtrmUOStLeGFB8dfTs9OTGdA5gwGd0ukfKPojstJI8reYP2ZEJEa1uEL3J/g4dUgXTh3S5evntpVVsGJzKcs27WHZplKWb97Dox9upzKwWSYxweidnUZu5wz6d05nSLc25Oe0wyzYaW5ERLzR4go9mKy0ZLL6JDO2T9bXz1XV1LJ2216WbdrD8kDZf7R6Oy9++iUAvzipH1d+v69XkUVEvkOF3ojEBB/9OqbTr2P6t07+vnNvJTe/vpS//vNzstOTmTSyh2cZRUTqU6EfpLatk/jzWYPZvreSX734Ge1bJ3NCbkevY4mItJzdFkMpMcHHfT8ZzqCumVzx9wUUrt/hdSQRERX6oWqd7OeRC0fSpU0rLn6sgJW6mIaIeEyFfhjapyXz+MX5JPl9XPDIPDbt3u91JBFpwVToh6l7u1Qeu2gkpeXVTH1kHrv3VXkdSURaKBV6CAzsksnMC0awbts+Lnl8vs4VIyKeUKGHyJjeWdw+aSgF63dy1dOfUq1zxYhIhKnQQ2jC4M7ceOpA/rV0C797ZYkuqCEiEaX90ENs6phebC0t557/rCY7LZlrTjrS60gi0kKo0MPgFycdSUlpBXf+3yqyM1I4f3RPryOJSAugQg8DM+OWMwaxvaySG15ZQlbrJMYP6ux1LBGJc9qGHib+BB93nzucYd3b8LPZC/l4zXavI4lInFOhh1GrpAQenjqSHu1TuXRWga5nKiJhpUIPs7atk5h1cT6tk/1MfWQeG3bs8zqSiMQpFXoEdG3TilkX51NeVcPUR+axY2+l15FEJA6p0CPkyE7pPDR1JMW79nPho/N4obCYRRt2UVquUwWISGiYVwe/5OXluYKCAk/m7aW3izbz89kL2V/v9ACdMlLo3aE1fbLT6N0h7evbDunJusydiHyLmRU65/KCTlOhR15VTS3rt+9j1dYyVpeUsfqr25K9lFVUfz0uPdnPEV8X/DeF37NdKv4E/XEl0hIdqNC1H7oHEhN89OmQRp8Oad963jnHlj0VXxf9V7fvryzhhQXFX4/zGST7E0hO9JHs95Hk95HsTyApwUdyoi9wm1Bv2lc/CST5fbRJTeS80T3JSEmM9K8uImGkQo8iZkanzBQ6ZaYwrm/Wt6btKa8KrMnv5Yvte9lfVUNFdS0VVbVU1tRSUV1DZXVt3XPVtezeX0VFVU3dtK/GBB6XV9Xyj8WbmHVxPllpyR79tiISair0GJGRksiwHm0Z1qPtYb/Xf1Zs5fInCznngbk8OW0UXdq0CkFCEfFaszbEmtnJZrbCzFaZ2fVBph9vZrvNbGHg54bQR5VQ+d6RHXj84lGU7Kng7PvnsnbbXq8jiUgINFnoZpYA3AOMB3KBKWaWG2To+865oYGfm0OcU0IsP6cdT08fzf6qGs6+fy5LN+ooVpFY15w19HxglXNujXOuEpgNTAxvLImEo7pm8uxlR5OYYEyeOZfC9Tu8jiQih6E5hd4V2FDvcXHguYaONrNFZvammQ0M9kZmNt3MCsysoKSk5BDiSqj16ZDGczOOpl3rJM57aB7vr9R/F5FY1ZxCD3ZkS8Od1xcAPZ1zQ4C7gJeDvZFzbqZzLs85l5ednX1wSSVsurVN5dkZR9OzfSrTHivgrSWbvY4kIoegOYVeDHSv97gbsLH+AOfcHudcWeD+G0CimX17vzuJah3SU3hm+tEc1TWDnz5VyHMFG5p+kYhEleYU+nygr5nlmFkSMBl4tf4AM+tkgWPUzSw/8L46AXiMyUxN5IlpoxjTO4tfPr+YRz5Y63UkETkITRa6c64auBJ4G1gGPOucKzKzGWY2IzDsLGCJmS0C7gQmO10hOSa1Tvbz8IV5/HBgR25+fSl3vLNSF7sWiRE6l4sEVV1Ty3UvfMYLC4qZNi6H304YoBOFiUQBnctFDpo/wcdfzhpMeoqfhz9YS2l5FX/68WASfCp1kWilQpdG+XzG70/NJbNVInf8eyWl5dX8bfJQkv0JXkcTkSB0DlY5IDPjv07sx28nDODNJZu5ZFYB+yqrm36hiESc1tClWS455ggyUhK5/sXFnPvgJxzX79CPI/D7jAGdMxjRsy1tWyeFMKVIy6ZCl2Y7Z2R30lP8/PL5xSzcsCsk79k7uzUje7VjRM+2jOzVjp7tU/Xlq8gh0l4uEnHlVTUsLt7N/HU7KFy/k4J1O9hTXrcZJystibye7cjr1ZYRPdsysEsmSX5tGRT5ivZykaiSkphAfk478nPaAVBb61hVUkbBurpyL1i/k7eKNgfG+hjSrU3dWnyvtgzv0ZbMVrrSkkgwWkOXqLR1TzkF63fWlfz6HRRt3ENNrcMMBnTK4NYzBzOoW6bXMUUiTheJlpi3r7KahV/somD9TmbP+4LKmlpevHwsPdqneh1NJKIOVOjaOCkxITXJz5g+WVz9g748cckoqmsdUx+dx/ayCq+jiUQNFbrEnN7ZaTw8NY+Nu/YzbVYB+ytrvI4kEhVU6BKTRvRsxx2Th7GoeBdXPf0p1TW1XkcS8ZwKXWLWyUd14qbTBvLOsi38/tUinRVSWjzttigx7YKje7FxVzn3v7eaLm1accX3+ngdScQzKnSJef/9wyPZvHs/f3l7BZ0yUjhzRDevI4l4QoUuMc/nM/581hBKyiq47oXFZKcnc+xhnGtGJFZpG7rEhSS/j/vOG0GfDmlc/mQhS77c7XUkkYhToUvcyEhJZNbF+WS2SuSix+azYcc+ryOJRJQKXeJKx4wUZl2cT0VVDRc+Oo9d+yq9jiQSMSp0iTt9O6bz4AV5bNixn0tmFVBepQOPpGVQoUtcGnVEe26fNJTCL3by89kLqanVPuoS/1ToErcmDO7Mbyfk8lbRZv7w+lIdeCRxT7stSlybNi6HTbv289AHa+nSJoXpx/b2OpJI2KjQJe79+pQBbN5Tzi1vLKdjRgoTh3b1OpJIWKjQJe75fMZt5wyhpLSCXzy3iOz0ZMb0zvI6lkjIaRu6tAjJ/gRmXpBHTlZrLnu8kNUlZV5HEgk5Fbq0GJmtEnnsonx8PuN3Ly/Rl6QSd1To0qJ0adOKX5zUj49Wb+cfn23yOo5ISKnQpcU5d1RPBnbJ4I+vL2NvRbXXcURCRoUuLU6Cz7h54lFs3lPOXf+3yus4IiHTrEI3s5PNbIWZrTKz6w8wbqSZ1ZjZWaGLKBJ6I3q25ewR3Xj4gzX6glTiRpOFbmYJwD3AeCAXmGJmuY2MuxV4O9QhRcLhuvH9aZWYwI26fJ3EieasoecDq5xza5xzlcBsYGKQcVcBLwBbQ5hPJGyy0pK59qQjeX/lNt5astnrOCKHrTmF3hXYUO9xceC5r5lZV+AM4P4DvZGZTTezAjMrKCkpOdisIiH3k1E9GNA5gz+8vpR9lfqCVGJbcwrdgjzX8O/TvwHXOecOeJ5S59xM51yecy4vO1uXCBPv+RN8/GHiQDbuLudufUEqMa45hV4MdK/3uBuwscGYPGC2ma0DzgLuNbPTQ5JQJMzyerXjx8O78uD7a1ijL0glhjWn0OcDfc0sx8ySgMnAq/UHOOdynHO9nHO9gOeBnzrnXg55WpEw+dX4AaT4E7jxNZ1mV2JXk4XunKsGrqRu75VlwLPOuSIzm2FmM8IdUCQSstOT+a8T+zHn8xLeLtridRyRQ2JerY3k5eW5goICT+YtEkx1TS0/uusDSsureeea42iVlOB1JJHvMLNC51xesGk6UlQkwJ/g46bTBvLlrv3c+66+IJXYo0IXqWfUEe05fWgXHnhvDeu27fU6jshBUaGLNPDrUwaQ5Pdx02s6glRiiwpdpIEOGSn8/IS+/GdFCe8s04HPEjtU6CJBTB3Ti34d07jptSLKqw54vJxI1FChiwSRmODjptOOonjnfu57d7XXcUSaRYUu0oije7fn1CFduO+91XyxfZ/XcUSapEIXOYDfnDKARJ9x8+tFXkcRaZIKXeQAOmWmcPUP+vLOsq38e5mOIJXopkIXacJFY3Pond2am15bqi9IJaqp0EWakOT3cfPEo/hixz4eeG+N13FEGqVCF2mGsX2ymDC4M/e+u4oNO/QFqUQnFbpIM/12wgASfMbNry/1OopIUH6vA4jEis6Zrbjq+3259a3lnHT7e3Rt04ougZ9v7qfQMSOFxAStK0nkqdBFDsK0cTnsr6ph6cY9bNy1n4UbdrFzX9W3xvgMOqSn0KVNyrfKvnNm3eNubVvRJjXJo99A4pkKXeQgJPl9XHNiv289t6+ymo27ytm4a/83P7vrHi/5cjf/LNpCZU3tt15z6pAu/G7CADpkpEQyvsQ5FbrIYUpN8tOnQxp9OqQFnV5b69i+t5KNu/azafd+Pt2wi0c/WMe7y7dy7Un9OG90T/zaRCMhoCsWiXhg7ba93PDKEt5fuY2BXTL44+lHMaxHW69jSQzQFYtEokxOVmsevzife84dzrayCn5830f86sXP2LWv0utoEsNU6CIeMTMmDO7Mv689nmljc3i2YAPfv+09nivYoAtryCFRoYt4LC3Zz29/lMtrV46jV/tUfvn8Ys55YC7LN+/xOprEGBW6SJTI7ZLB8zPGcOuZg1i5tYwJd37ALW8sY29FtdfRJEao0EWiiM9nTBrZg/+79njOGt6NmXPWcML/vsebn23SZhhpkgpdJAq1a53ErWcN5oXLjyazVSKXP7WAix6bz/rte72OJlFMhS4SxUb0bMfrV43jdz/KZf7aHZx4+xzueGclldW1Tb9YWhwVukiU8yf4mDYuh39fezwn5nbk9nc+59LHC3RudvkOFbpIjOiUmcI95w7nTz8exJyVJUybNZ99lfrCVL6hQheJMVPye/CXs4Ywd/V2Lnx0PmXaC0YCVOgiMeisEd342+RhFK7fyQUPf8Ke8qqmXyRxT4UuEqNOG9KFu6cMY3Hxbs576BOdNkBU6CKxbPygztx/3giWbyrl3Ac/YcdelXpL1qxCN7OTzWyFma0ys+uDTJ9oZovNbKGZFZjZuNBHFZFgTsjtyINT81hdUsbkmXMpKa3wOpJ4pMlCN7ME4B5gPJALTDGz3AbD/g0Mcc4NBS4GHgp1UBFp3HH9snn0wpFs2LGfSTPnsnl3udeRxAPNWUPPB1Y559Y45yqB2cDE+gOcc2Xum+OSWwM6Rlkkwsb0yWLWxfls2V3OpJlz+XLXfq8jSYQ1p9C7AhvqPS4OPPctZnaGmS0H/kHdWvp3mNn0wCaZgpKSkkPJKyIHkJ/TjicuGcWOvZVMemAuG3bs8zqSRFBzCt2CPPedNXDn3EvOuf7A6cAfgr2Rc26mcy7POZeXnZ19cElFpFmG92jL3y8ZTWl5Nec8MJe123T+l5aiOYVeDHSv97gbsLGxwc65OUBvM8s6zGwicogGdcvk6UtHU1Fdy6QH5rJqa6nXkSQCmlPo84G+ZpZjZknAZODV+gPMrI+ZWeD+cCAJ2B7qsCLSfLldMpg9fTS1DiY98LEumNECNFnozrlq4ErgbWAZ8KxzrsjMZpjZjMCwM4ElZraQuj1iJjmdvFnEc/06pvPMZaPxJxhTZn7Mki93ex1Jwsi86t28vDxXUFDgybxFWpr12/dy7oOfUFpexePTRjG0exuvI8khMrNC51xe0GkqdJGWoXjnvq+PJj22XxbJ/gRSEn0k+xNITvSRcoDblMRv7qcmJ5DTvjU+X7D9JSTcDlTo/kiHERFvdGubyjOXjea6Fz7j8y1llFfVUFFd+/XtwVw045i+Wdw+aShZaclhTCwHS2voIgJAba2jsqaWiqpayqtrgt9W1bC6pIzb/vU5bVMTuWvKcPJz2nkdvUXRGrqINMnnM1J8dZtXMklsdNwJdGRc3yyueGoBUx78mF/+8EimH3OENsFEAZ1tUUQO2sAumbx21Th+OLAj//Pmci59vECn740CKnQROSTpKYncc+5wbjw1lzkrS5hw5wcs3LDL61gtmgpdRA6ZmXHh2ByemzEGgLPv/4hHP1yLDkPxhgpdRA7b0O5t+MfV4zi2bzY3vbaUK/6+QJfF84AKXURCok1qEg9ekMevxvfn7aItnHbXBxRt1JGpkaRCF5GQ8fmMy47rzezpo9lfVcMZ937E0/O+0CaYCFGhi0jIjezVjjeuPoZROe341Yufcc2zi9hbUe11rLinQheRsGiflsxjF+VzzYn9eHnhl0y850NWbtFpfMNJhS4iYZPgM67+QV+enDaKXfsqOe3uD3np02KvY8UtFbqIhN3YPlm8cfUxDOqWyX89s4hb3lim7ephoEIXkYjokJHC3y8ZxfmjezJzzhp+8/ISamtV6qGkc7mISMT4E3zcPHEgaSl+7nt3Nfsra/jLWYPxJ2jdMhRU6CISUWbGdSf3Jy3Zz1/eXsG+ymrunDKMZH+C19Finv5ZFBFPXPG9Pvz+1FzeLtrCJbMK2F9Z43WkmKdCFxHPXDQ2hz+fOZgPVm1j6iPzKNXpAg6LCl1EPHXOyO7cOXkYC77YyU8e+oSde3Ua3kOlQhcRz506pAv3nzeC5ZtLmTzzY7aWlnsdKSap0EUkKpyQ25FHLxzJhp37mPTAx3y5a7/XkWKOCl1EosbYPlk8MS2fbWUVnHP/XNZu2+t1pJiiQheRqDKiZzuevrTubI3nPDCXFZt1/pfmUqGLSNQ5qmsmz0wfjc9g0sy5LC7Wpe2aQ4UuIlGpb8d0nrtsDGnJfs598BPmrd3hdaSop0IXkajVo30qz804mg4ZyVzwyCfM+bzE60hRTYUuIlGtc2Yrnr3saHKy0rhkVgFvF232OlLU0rlcRCTqZaUlM/vS0Ux9dB4/fWoBI3q0JS3FT1qyn7QUP+nJdffTU/ykpSR+c7/+9BQ/rRITMDOvf52wUaGLSEzITE3kyUtGccsby1hbspeS0grWbttLaXk1peVVVFTXNvkePoOMVolMye/BtSf2i7uzPDar0M3sZOAOIAF4yDn3Pw2m/wS4LvCwDLjcObcolEFFRNKS/dxyxqCg0yqra9lbUU1ZRTWl5XW3ZRVVgcIPPC6vZnVJGfe9u5rCdTu569xhdMxIifBvET5NFrqZJQD3ACcCxcB8M3vVObe03rC1wHHOuZ1mNh6YCYwKR2ARkWCS/D6S/Em0bZ3U5NiXP/2SX7/0Gafc8T53TB7GuL5ZEUgYfs35eyMfWOWcW+OcqwRmAxPrD3DOfeSc2xl4+DHQLbQxRURC5/RhXXn1yrG0T0vi/Ec+4W/vfE5NHFw9qTmF3hXYUO9xceC5xkwD3jycUCIi4danQzovXzGWM4Z15W/vrGTqI/PYVlbhdazD0pxCD/aVcNB/yszse9QV+nWNTJ9uZgVmVlBSov1JRcRbqUl+bjt7CLeeOYj563Yw4c73Y/oApuYUejHQvd7jbsDGhoPMbDDwEDDRObc92Bs552Y65/Kcc3nZ2dmHkldEJKTMjEkje/DST8eSmuRnyoMfc/97q2PyAtbNKfT5QF8zyzGzJGAy8Gr9AWbWA3gRON8593noY4qIhFdulwxevXIsJw/sxP+8uZxLHy9g177YuthGk4XunKsGrgTeBpYBzzrnisxshpnNCAy7AWgP3GtmC82sIGyJRUTCJD0lkbvPHcZNpw1kzsoSJtz5AQs3xM6Jwcw5b/6syMvLcwUF6n0RiU4LN+ziiqcWsLW0nN+cMoCpY3pFxVGmZlbonMsLNi2+DpMSEQmRod3b8I+rx3Fs32xufG0pV/7906i/iLUKXUSkEW1Sk3jwgjx+Nb4/bxVt5tS7PmDpxj1ex2qUCl1E5AB8PuOy43p/fRWlM+79MGrP+KhCFxFphvycdvzj6mPo3zmDn89eyLJN0bemrkIXEWmmrLRkHjx/BOkpfqY/EX27NarQRUQOQoeMFO4/fwRbdldw1dOfUl3T9Gl7I0WFLiJykIb3aMsfTz+K91du489vr/A6ztd0gQsRkUNwzsjuLNm4m5lz1jCwSwYThx7onIWRoTV0EZFD9Lsf5ZLfqx3//fxilny52+s4KnQRkUOVmODjnp8Mp13rJC57opDtHp9+V4UuInIYstOTeeD8EZSUVXDl3z+lysMvSVXoIiKHaXC3NvzpjEHMXbOdW95Y5lkOfSkqIhICZ47oRtHGPTzy4VoGdsnkrBGRvxKn1tBFRELk16f0Z0zv9vz6pc9Y5MFpd1XoIiIh4k/wcfe5w8lOS+ayJwopKY3sl6QqdBGREGrXOomZF4xg1/5KfvpUIZXVkfuSVIUuIhJiA7tkcuuZg5m/bid/eH1pxOarL0VFRMJg4tCuLN24hwcCR5JOzu8R9nlqDV1EJEz+++T+HNM3ixteKaJw/c6wz0+FLiISJgk+464pw+iUmcLlTxayZU95WOenQhcRCaM2qXVfkpZVVDPjyUIqqmvCNi8VuohImPXvlMFfzx7Cp1/s4vevFOGcC8t8VOgiIhFwyqDOXPG93syev4GnPvkiLPPQXi4iIhFyzYlH8sWO/XRITw7L+6vQRUQi5KsvScNFm1xEROKECl1EJE6o0EVE4oQKXUQkTqjQRUTihApdRCROqNBFROKECl1EJE5YuM4p0OSMzUqA9Yf48ixgWwjjhFq054Poz6h8h0f5Dk805+vpnMsONsGzQj8cZlbgnMvzOkdjoj0fRH9G5Ts8ynd4oj1fY7TJRUQkTqjQRUTiRKwW+kyvAzQh2vNB9GdUvsOjfIcn2vMFFZPb0EVE5LtidQ1dREQaUKGLiMSJqC50MzvZzFaY2Sozuz7IdDOzOwPTF5vZ8Ahm625m/zGzZWZWZGY/CzLmeDPbbWYLAz83RCpfYP7rzOyzwLwLgkz3cvkdWW+5LDSzPWb28wZjIr78zOwRM9tqZkvqPdfOzP5lZisDt20bee0BP69hzPcXM1se+G/4kpm1aeS1B/w8hDHfjWb2Zb3/jqc08lqvlt8z9bKtM7OFjbw27MvvsDnnovIHSABWA0cAScAiILfBmFOANwEDRgOfRDBfZ2B44H468HmQfMcDr3u4DNcBWQeY7tnyC/LfejN1B0x4uvyAY4HhwJJ6z/0ZuD5w/3rg1kZ+hwN+XsOY7yTAH7h/a7B8zfk8hDHfjcAvmvEZ8GT5NZh+G3CDV8vvcH+ieQ09H1jlnFvjnKsEZgMTG4yZCDzu6nwMtDGzzpEI55zb5JxbELhfCiwDukZi3iHk2fJr4AfAaufcoR45HDLOuTnAjgZPTwRmBe7PAk4P8tLmfF7Dks8590/nXHXg4cdAt1DPt7kaWX7N4dny+4qZGXAO8HSo5xsp0VzoXYEN9R4X893CbM6YsDOzXsAw4JMgk482s0Vm9qaZDYxoMHDAP82s0MymB5keFcsPmEzj/xN5ufy+0tE5twnq/iEHOgQZEy3L8mLq/uoKpqnPQzhdGdgk9Egjm6yiYfkdA2xxzq1sZLqXy69ZornQLchzDfexbM6YsDKzNOAF4OfOuT0NJi+gbjPCEOAu4OVIZgPGOueGA+OBK8zs2AbTo2H5JQGnAc8Fmez18mVVAMoAAAIBSURBVDsY0bAsfwNUA081MqSpz0O43Af0BoYCm6jbrNGQ58sPmMKB1869Wn7NFs2FXgx0r/e4G7DxEMaEjZklUlfmTznnXmw43Tm3xzlXFrj/BpBoZlmRyuec2xi43Qq8RN2ftfV5uvwCxgMLnHNbGk7wevnVs+WrTVGB261Bxnj9WZwK/Aj4iQts8G2oGZ+HsHDObXHO1TjnaoEHG5mv18vPD/wYeKaxMV4tv4MRzYU+H+hrZjmBtbjJwKsNxrwKXBDYW2M0sPurP43DLbC97WFgmXPufxsZ0ykwDjPLp255b49QvtZmlv7Vfeq+OFvSYJhny6+eRteKvFx+DbwKTA3cnwq8EmRMcz6vYWFmJwPXAac55/Y1MqY5n4dw5av/vcwZjczXs+UXcAKw3DlXHGyil8vvoHj9reyBfqjbC+Nz6r79/k3guRnAjMB9A+4JTP8MyItgtnHU/Um4GFgY+DmlQb4rgSLqvrH/GBgTwXxHBOa7KJAhqpZfYP6p1BV0Zr3nPF1+1P3jsgmoom6tcRrQHvg3sDJw2y4wtgvwxoE+rxHKt4q67c9ffQ7vb5ivsc9DhPI9Efh8LaaupDtH0/ILPP/YV5+7emMjvvwO90eH/ouIxIlo3uQiIiIHQYUuIhInVOgiInFChS4iEidU6CIicUKFLiISJ1ToIiJx4v8B+HoLaCtVelUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "plt.plot([confidence[rule[0]] for rule in sorted_confidence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "dataset = load_iris()\n",
    "X = dataset.data\n",
    "y = dataset.target\n",
    "# print(dataset.DESCR)\n",
    "n_samples, n_features = X.shape\n",
    "attribute_means = X.mean(axis=0)\n",
    "assert attribute_means.shape == (n_features,)\n",
    "X_d = np.array(X >= attribute_means, dtype='int') #X_dicretized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from operator import itemgetter\n",
    "def train(X, y_true, feature):\n",
    "    n_samples, n_features = X.shape\n",
    "    assert 0 <= feature < n_features\n",
    "    values = set(X[:,feature])\n",
    "    predictors = dict()\n",
    "    errors = []\n",
    "    for current_value in values:\n",
    "        most_frequent_class, error = train_feature_value(X, y_true, feature, current_value)\n",
    "        predictors[current_value] = most_frequent_class\n",
    "        errors.append(error)\n",
    "    total_error = sum(errors)\n",
    "    return predictors, total_error\n",
    "def train_feature_value (x, y_true, feature, value):\n",
    "    class_counts = defaultdict(int)\n",
    "    for sample, y in zip(x, y_true):\n",
    "        if sample[feature] == value: #value? empty\n",
    "            class_counts[y] += 1\n",
    "    sorted_class_counts = sorted(class_counts.items(), key = itemgetter(1), reverse = True)\n",
    "    most_frequent_class = sorted_class_counts[0][0]\n",
    "    n_samples = X.shape[1]\n",
    "    error = sum([class_count for class_value, class_count in class_counts.items() if class_value != most_frequent_class])\n",
    "    return most_frequent_class, error\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are (112,) training samples\n",
      "There are (38,) testing samples\n"
     ]
    }
   ],
   "source": [
    "# Now, we split into a training and test set\n",
    "from sklearn.model_selection import train_test_split\n",
    "# from sklearn.cross_validation import train_test_split\n",
    "\n",
    "# Set the random state to the same number to get the same results as in the book\n",
    "random_state = 14\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X_d, y, random_state=random_state)\n",
    "print(\"There are {} training samples\".format(y_train.shape))\n",
    "print(\"There are {} testing samples\".format(y_test.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The best model is based on variable 2 and has error 37.00\n",
      "{'variable': 2, 'predictor': {0: 0, 1: 2}}\n"
     ]
    }
   ],
   "source": [
    "# Compute all of the predictors\n",
    "all_predictors = {variable: train(X_train, y_train, variable) for variable in range(X_train.shape[1])}\n",
    "errors = {variable: error for variable, (mapping, error) in all_predictors.items()}\n",
    "# Now choose the best and save that as \"model\"\n",
    "# Sort by error\n",
    "best_variable, best_error = sorted(errors.items(), key=itemgetter(1))[0]\n",
    "print(\"The best model is based on variable {0} and has error {1:.2f}\".format(best_variable, best_error))\n",
    "\n",
    "# Choose the bset model\n",
    "model = {'variable': best_variable, # the best obvious feature\n",
    "         'predictor': all_predictors[best_variable][0]} # the best feature's predictor, get predictor without error\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 2 2 2 0 2 0 2 2 0 2 2 0 2 0 2 2 2 0 0 0 2 0 2 0 2 2 0 0 0 2 0 2 0 2\n",
      " 2]\n"
     ]
    }
   ],
   "source": [
    "def predict(X_test, model):\n",
    "    variable = model['variable']\n",
    "    predictor = model['predictor']\n",
    "    y_predicted = np.array([predictor [int(sample[variable])] for sample in X_test]) # use obvious feature to predict\n",
    "    return y_predicted\n",
    "y_predicted = predict(X_test, model)\n",
    "print(y_predicted)\n",
    "# but there are three different species on the category, each predictor can just predict two species\n",
    "# we can't say that the \"other categorical species are similar\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: ({0: 0, 1: 2}, 41), 1: ({0: 1, 1: 0}, 58), 2: ({0: 0, 1: 2}, 37), 3: ({0: 0, 1: 2}, 37)}\n",
      "{0: 41, 1: 58, 2: 37, 3: 37}\n",
      "[[0 1 0 0]\n",
      " [0 1 0 0]\n",
      " [0 1 0 0]\n",
      " [0 0 1 1]\n",
      " [0 0 1 1]]\n",
      "[0 0 0 1 2]\n"
     ]
    }
   ],
   "source": [
    "print(all_predictors)\n",
    "# four features with 2 different possible values\n",
    "# 0(feature): ({0(feature valune):0, 1(other category):2},41(rate of wrong cal))\n",
    "print(errors)\n",
    "print(X_test[:5])\n",
    "print(y_test[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The test accuracy is 65.8%\n"
     ]
    }
   ],
   "source": [
    "# Compute the accuracy by taking the mean of the amounts that y_predicted is equal to y_test\n",
    "accuracy = np.mean(y_predicted == y_test) * 100\n",
    "print(\"The test accuracy is {:.1f}%\".format(accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.94      1.00      0.97        17\n",
      "           1       0.00      0.00      0.00        13\n",
      "           2       0.40      1.00      0.57         8\n",
      "\n",
      "    accuracy                           0.66        38\n",
      "   macro avg       0.45      0.67      0.51        38\n",
      "weighted avg       0.51      0.66      0.55        38\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "S:\\Anaconda\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "print(classification_report(y_test, y_predicted))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "config Completer.use_jedi = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "oldHeight": 122.4,
   "position": {
    "height": "40px",
    "left": "1070px",
    "right": "20px",
    "top": "148px",
    "width": "350px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "varInspector_section_display": "none",
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
